#include "decode_yolov8.h"

namespace ns_uestc_vhm {

__global__ void decode_yolov8_device_kernel(int batch_size, int num_class, int topK, float conf_thresh,
                                            float *src, int srcWidth, int srcHeight, int srcArea,
                                            float *dst, int dstWidth, int dstHeight, int dstArea) {
    int dx = blockDim.x * blockIdx.x + threadIdx.x;
    int dy = blockDim.y * blockIdx.y + threadIdx.y;
    if (dx >= srcHeight || dy >= batch_size) {
        return;
    }
    float *pitem = src + dy * srcArea + dx * srcWidth;
    float *class_confidence = pitem + 4;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_class; ++i, ++class_confidence) {
        if (*class_confidence > confidence) {
            confidence = *class_confidence;
            label = i;
        }
    }
    if (confidence < conf_thresh) {
        return;
    }
    int index = atomicAdd(dst + dy * dstArea, 1);

    if (index >= topK) {
        return;
    }
    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;

    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    float *pout_item = dst + dy * dstArea + 1 + index * dstWidth;
    *pout_item++ = left;
    *pout_item++ = top;
    *pout_item++ = right;
    *pout_item++ = bottom;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1;
}

void yolov8::decodeDevice(InitParameter param, float *src, int srcWidth, int srcHeight, int srcArea, float *dst, int dstWidth, int dstHeight) {
    dim3 block_size(BLOCK_SIZE, BLOCK_SIZE);
    dim3 grid_size((srcHeight + BLOCK_SIZE - 1) / BLOCK_SIZE,
                   (param.batch_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
    int dstArea = 1 + dstWidth * dstHeight;

    decode_yolov8_device_kernel<<<grid_size, block_size, 0, nullptr>>>(param.batch_size, param.num_class, param.topK, param.conf_thresh,
                                                                       src, srcWidth, srcHeight, srcArea,
                                                                       dst, dstWidth, dstHeight, dstArea);
}

__global__ void transpose_device_kernel(int batch_size,
                                        float *src, int srcWidth, int srcHeight, int srcArea,
                                        float *dst, int dstWidth, int dstHeight, int dstArea) {
    int dx = blockDim.x * blockIdx.x + threadIdx.x;
    int dy = blockDim.y * blockIdx.y + threadIdx.y;
    if (dx >= dstHeight || dy >= batch_size) {
        return;
    }
    float *p_dst_row = dst + dy * dstArea + dx * dstWidth;
    float *p_src_col = src + dy * srcArea + dx;

    for (int i = 0; i < dstWidth; i++) {
        p_dst_row[i] = p_src_col[i * srcWidth];
    }
}

void yolov8::transposeDevice(InitParameter param,
                             float *src, int srcWidth, int srcHeight, int srcArea,
                             float *dst, int dstWidth, int dstHeight) {
    dim3 block_size(BLOCK_SIZE, BLOCK_SIZE);
    dim3 grid_size((dstHeight + BLOCK_SIZE - 1) / BLOCK_SIZE,
                   (param.batch_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
    int dstArea = dstWidth * dstHeight;

    transpose_device_kernel<<<grid_size, block_size, 0, nullptr>>>(param.batch_size,
                                                                   src, srcWidth, srcHeight, srcArea,
                                                                   dst, dstWidth, dstHeight, dstArea);
}

} // ns_uestc_vhm
